"""Helpers for introspecting and wrapping annotations."""

import ast
import enum
import functools
import sys
import types

# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "Format",
    "ForwardRef",
    "call_annotate_function",
    "call_evaluate_function",
    "get_annotate_function",
    "get_annotations",
]


class Format(enum.IntEnum):
    """Formats in which annotations can be requested.

    VALUE asks for the evaluated annotation values, FORWARDREF asks for values
    with ForwardRef placeholders for names that cannot be resolved, and SOURCE
    asks for strings approximating the annotation source code. Members are
    ints, so an __annotate__ function may receive either a member or its
    integer value.
    """

    VALUE = 1  # evaluated values
    FORWARDREF = 2  # values, with ForwardRef objects for unresolved names
    SOURCE = 3  # strings reconstructing the annotation source


# Bound to typing.Union on first use by ForwardRef.__or__/__ror__, so that
# typing is only imported on demand.
_Union = None
# Marker meaning "argument not passed" (see the type_params default of
# ForwardRef._evaluate).
_sentinel = object()

# Slots shared by ForwardRef and _Stringifier. The __forward__ names must be
# preserved for compatibility with the old typing.ForwardRef class. The remaining
# names are private.
# Both classes use exactly this tuple as __slots__. That is what allows a
# _Stringifier to be turned into a ForwardRef in place by assigning __class__, so
# the two classes must keep sharing it.
_SLOTS = (
    "__forward_evaluated__",
    "__forward_value__",
    "__forward_is_argument__",
    "__forward_is_class__",
    "__forward_module__",
    "__weakref__",
    "__arg__",
    "__ast_node__",
    "__code__",
    "__globals__",
    "__owner__",
    "__cell__",
)


class ForwardRef:
    """Wrapper that holds a forward reference.

    The reference is kept as source text (``__forward_arg__``) and can be
    resolved to a value with evaluate(). Objects produced by
    call_annotate_function() start out as _Stringifier instances and may hold
    an AST node instead, which is unparsed to text on first access.
    """

    # Shared with _Stringifier so instances can be converted via __class__.
    __slots__ = _SLOTS

    def __init__(
        self,
        arg,
        *,
        module=None,
        owner=None,
        is_argument=True,
        is_class=False,
        _globals=None,
        _cell=None,
    ):
        if not isinstance(arg, str):
            raise TypeError(f"Forward reference must be a string -- got {arg!r}")

        self.__arg__ = arg
        self.__forward_evaluated__ = False
        self.__forward_value__ = None
        self.__forward_is_argument__ = is_argument
        self.__forward_is_class__ = is_class
        self.__forward_module__ = module
        self.__code__ = None
        self.__ast_node__ = None
        self.__globals__ = _globals
        self.__cell__ = _cell
        self.__owner__ = owner

    def __init_subclass__(cls, /, *args, **kwds):
        raise TypeError("Cannot subclass ForwardRef")

    def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None):
        """Evaluate the forward reference and return the value.

        If the forward reference cannot be evaluated, raise an exception.

        Once evaluation succeeds the value is cached, and later calls return it
        regardless of the arguments passed.

        *globals* defaults, in order, to:

        * the namespace of the module named at construction;
        * the globals captured at creation;
        * a namespace derived from *owner* (the owner's module for a class,
          the module itself for a module, ``__globals__`` for a callable);
        * an empty namespace.

        *locals* defaults to the class namespace when *owner* is a class, and
        to an empty mapping otherwise. *type_params* defaults to
        ``owner.__type_params__``.
        """
        if self.__forward_evaluated__:
            return self.__forward_value__
        if self.__cell__ is not None:
            try:
                value = self.__cell__.cell_contents
            except ValueError:
                # Empty cell: fall back to evaluating the source text below.
                pass
            else:
                self.__forward_evaluated__ = True
                self.__forward_value__ = value
                return value
        if owner is None:
            owner = self.__owner__

        if globals is None and self.__forward_module__ is not None:
            globals = getattr(
                sys.modules.get(self.__forward_module__, None), "__dict__", None
            )
        if globals is None:
            globals = self.__globals__
        if globals is None:
            if isinstance(owner, type):
                module_name = getattr(owner, "__module__", None)
                if module_name:
                    module = sys.modules.get(module_name, None)
                    if module:
                        globals = getattr(module, "__dict__", None)
            elif isinstance(owner, types.ModuleType):
                globals = getattr(owner, "__dict__", None)
            elif callable(owner):
                globals = getattr(owner, "__globals__", None)
        # eval() treats globals=None as "use the caller's globals". That would
        # resolve names against this module's namespace, so use an empty one.
        # eval() still inserts __builtins__ into it.
        if globals is None:
            globals = {}

        if locals is None:
            locals = {}
            if isinstance(owner, type):
                locals.update(vars(owner))

        if type_params is None and owner is not None:
            # "Inject" type parameters as a way of emulating annotation scopes
            # when calling `eval()`.
            type_params = getattr(owner, "__type_params__", None)

        # Type parameters live in their own scope, but eval() only accepts a
        # globals and a locals mapping. Put them into a copy of globals so they
        # shadow module-level names.
        # For a class forward reference, locals holds the class namespace, which
        # must in turn shadow the type parameters, so it is left untouched.
        # Otherwise the type parameters also win over same-named locals.
        if type_params is not None:
            globals = dict(globals)
            locals = dict(locals)
            for param in type_params:
                param_name = param.__name__
                globals[param_name] = param
                if not self.__forward_is_class__:
                    locals.pop(param_name, None)

        code = self.__forward_code__
        value = eval(code, globals=globals, locals=locals)
        self.__forward_evaluated__ = True
        self.__forward_value__ = value
        return value

    def _evaluate(self, globalns, localns, type_params=_sentinel, *, recursive_guard):
        """Deprecated private API; delegates to typing.evaluate_forward_ref()."""
        import typing
        import warnings

        if type_params is _sentinel:
            typing._deprecation_warning_for_no_type_params_passed(
                "typing.ForwardRef._evaluate"
            )
            type_params = ()
        warnings._deprecated(
            "ForwardRef._evaluate",
            "{name} is a private API and is retained for compatibility, but will be removed"
            " in Python 3.16. Use ForwardRef.evaluate() or typing.evaluate_forward_ref() instead.",
            remove=(3, 16),
        )
        return typing.evaluate_forward_ref(
            self,
            globals=globalns,
            locals=localns,
            type_params=type_params,
            _recursive_guard=recursive_guard,
        )

    @property
    def __forward_arg__(self):
        """Source text of the reference, unparsed lazily from the AST if needed."""
        if self.__arg__ is not None:
            return self.__arg__
        if self.__ast_node__ is not None:
            self.__arg__ = ast.unparse(self.__ast_node__)
            return self.__arg__
        raise AssertionError(
            "Attempted to access '__forward_arg__' on an uninitialized ForwardRef"
        )

    @property
    def __forward_code__(self):
        """Compiled code object for the reference (compiled once and cached)."""
        if self.__code__ is not None:
            return self.__code__
        arg = self.__forward_arg__
        # If we do `def f(*args: *Ts)`, then we'll have `arg = '*Ts'`.
        # Unfortunately, this isn't a valid expression on its own, so we
        # do the unpacking manually.
        if arg.startswith("*"):
            arg_to_compile = f"({arg},)[0]"  # E.g. (*Ts,)[0] or (*tuple[int, int],)[0]
        else:
            arg_to_compile = arg
        try:
            self.__code__ = compile(arg_to_compile, "<string>", "eval")
        except SyntaxError:
            raise SyntaxError(f"Forward reference must be an expression -- got {arg!r}")
        return self.__code__

    def __eq__(self, other):
        if not isinstance(other, ForwardRef):
            return NotImplemented
        if self.__forward_evaluated__ and other.__forward_evaluated__:
            return (
                self.__forward_arg__ == other.__forward_arg__
                and self.__forward_value__ == other.__forward_value__
            )
        return (
            self.__forward_arg__ == other.__forward_arg__
            and self.__forward_module__ == other.__forward_module__
        )

    def __hash__(self):
        # When both sides are evaluated, __eq__ compares __forward_value__ and
        # ignores __forward_module__, so equal references may have different
        # modules. Hash only the argument, which every branch of __eq__
        # compares, so that hash() stays consistent with ==.
        return hash(self.__forward_arg__)

    def __or__(self, other):
        global _Union
        if _Union is None:
            from typing import Union as _Union
        return _Union[self, other]

    def __ror__(self, other):
        global _Union
        if _Union is None:
            from typing import Union as _Union
        return _Union[other, self]

    def __repr__(self):
        if self.__forward_module__ is None:
            module_repr = ""
        else:
            module_repr = f", module={self.__forward_module__!r}"
        return f"ForwardRef({self.__forward_arg__!r}{module_repr})"


class _Stringifier:
    """Placeholder object that records the operations applied to it as an AST.

    Every supported operation (attribute access, subscription, calls, binary,
    reflected and unary operators, comparisons, star-unpacking) returns a new
    _Stringifier wrapping the corresponding AST node. repr() of the result is
    therefore an approximation of the source text of the expression.
    """

    # Must match the slots on ForwardRef, so we can turn an instance of one into an
    # instance of the other in place.
    __slots__ = _SLOTS

    def __init__(self, node, globals=None, owner=None, is_class=False, cell=None):
        # *node* is the AST of the expression this object stands for. The other
        # arguments fill the ForwardRef-compatible slots, so the object behaves
        # sensibly if its __class__ is later switched to ForwardRef.
        assert isinstance(node, ast.AST)
        self.__arg__ = None
        self.__forward_evaluated__ = False
        self.__forward_value__ = None
        self.__forward_is_argument__ = False
        self.__forward_is_class__ = is_class
        self.__forward_module__ = None
        self.__code__ = None
        self.__ast_node__ = node
        self.__globals__ = globals
        self.__cell__ = cell
        self.__owner__ = owner

    def __convert(self, other):
        # Turn an operand into an AST node:
        # - other stringifiers contribute their own node;
        # - slices are converted component-wise;
        # - anything else is embedded as a constant.
        if isinstance(other, _Stringifier):
            return other.__ast_node__
        elif isinstance(other, slice):
            return ast.Slice(
                lower=self.__convert(other.start) if other.start is not None else None,
                upper=self.__convert(other.stop) if other.stop is not None else None,
                step=self.__convert(other.step) if other.step is not None else None,
            )
        else:
            return ast.Constant(value=other)

    def __make_new(self, node):
        # Wrap *node* in a new stringifier that shares globals, owner and
        # is_class with this one. The cell is not propagated, and the new object
        # is not registered anywhere.
        return _Stringifier(
            node, self.__globals__, self.__owner__, self.__forward_is_class__
        )

    # Must implement this since we set __eq__. We hash by identity so that
    # stringifiers in dict keys are kept separate.
    def __hash__(self):
        return id(self)

    def __getitem__(self, other):
        # Special case, to avoid stringifying references to class-scoped variables
        # as '__classdict__["x"]'.
        if (
            isinstance(self.__ast_node__, ast.Name)
            and self.__ast_node__.id == "__classdict__"
        ):
            raise KeyError
        # A tuple subscript (x[a, b]) becomes a single Tuple node.
        if isinstance(other, tuple):
            elts = [self.__convert(elt) for elt in other]
            other = ast.Tuple(elts)
        else:
            other = self.__convert(other)
        assert isinstance(other, ast.AST), repr(other)
        return self.__make_new(ast.Subscript(self.__ast_node__, other))

    # Only called for attributes not found normally (all slots are set in
    # __init__). Any other attribute access is recorded as an Attribute node.
    def __getattr__(self, attr):
        return self.__make_new(ast.Attribute(self.__ast_node__, attr))

    def __call__(self, *args, **kwargs):
        return self.__make_new(
            ast.Call(
                self.__ast_node__,
                [self.__convert(arg) for arg in args],
                [
                    ast.keyword(key, self.__convert(value))
                    for key, value in kwargs.items()
                ],
            )
        )

    # Star-unpacking (e.g. ``*Ts``) yields exactly one Starred node.
    def __iter__(self):
        yield self.__make_new(ast.Starred(self.__ast_node__))

    # The "stringified" form: the expression unparsed back to source text.
    def __repr__(self):
        return ast.unparse(self.__ast_node__)

    # f-strings/format() are not supported inside stringified annotations.
    def __format__(self, format_spec):
        raise TypeError("Cannot stringify annotation containing string formatting")

    # Factory for binary operators: ``self <op> other``.
    def _make_binop(op: ast.AST):
        def binop(self, other):
            return self.__make_new(
                ast.BinOp(self.__ast_node__, op, self.__convert(other))
            )

        return binop

    __add__ = _make_binop(ast.Add())
    __sub__ = _make_binop(ast.Sub())
    __mul__ = _make_binop(ast.Mult())
    __matmul__ = _make_binop(ast.MatMult())
    __truediv__ = _make_binop(ast.Div())
    __mod__ = _make_binop(ast.Mod())
    __lshift__ = _make_binop(ast.LShift())
    __rshift__ = _make_binop(ast.RShift())
    __or__ = _make_binop(ast.BitOr())
    __xor__ = _make_binop(ast.BitXor())
    __and__ = _make_binop(ast.BitAnd())
    __floordiv__ = _make_binop(ast.FloorDiv())
    __pow__ = _make_binop(ast.Pow())

    # Remove the factory so it does not remain as a class attribute.
    del _make_binop

    # Factory for reflected operators: ``other <op> self``, used when the left
    # operand does not handle the operation itself.
    def _make_rbinop(op: ast.AST):
        def rbinop(self, other):
            return self.__make_new(
                ast.BinOp(self.__convert(other), op, self.__ast_node__)
            )

        return rbinop

    __radd__ = _make_rbinop(ast.Add())
    __rsub__ = _make_rbinop(ast.Sub())
    __rmul__ = _make_rbinop(ast.Mult())
    __rmatmul__ = _make_rbinop(ast.MatMult())
    __rtruediv__ = _make_rbinop(ast.Div())
    __rmod__ = _make_rbinop(ast.Mod())
    __rlshift__ = _make_rbinop(ast.LShift())
    __rrshift__ = _make_rbinop(ast.RShift())
    __ror__ = _make_rbinop(ast.BitOr())
    __rxor__ = _make_rbinop(ast.BitXor())
    __rand__ = _make_rbinop(ast.BitAnd())
    __rfloordiv__ = _make_rbinop(ast.FloorDiv())
    __rpow__ = _make_rbinop(ast.Pow())

    del _make_rbinop

    # Factory for comparisons. They return a new stringifier recording the
    # comparison, not a bool; this is why __hash__ is defined explicitly above.
    def _make_compare(op):
        def compare(self, other):
            return self.__make_new(
                ast.Compare(
                    left=self.__ast_node__,
                    ops=[op],
                    comparators=[self.__convert(other)],
                )
            )

        return compare

    __lt__ = _make_compare(ast.Lt())
    __le__ = _make_compare(ast.LtE())
    __eq__ = _make_compare(ast.Eq())
    __ne__ = _make_compare(ast.NotEq())
    __gt__ = _make_compare(ast.Gt())
    __ge__ = _make_compare(ast.GtE())

    del _make_compare

    # Factory for unary operators: ``<op> self``.
    def _make_unary_op(op):
        def unary_op(self):
            return self.__make_new(ast.UnaryOp(op, self.__ast_node__))

        return unary_op

    __invert__ = _make_unary_op(ast.Invert())
    __pos__ = _make_unary_op(ast.UAdd())
    __neg__ = _make_unary_op(ast.USub())

    del _make_unary_op


class _StringifierDict(dict):
    def __init__(self, namespace, globals=None, owner=None, is_class=False):
        super().__init__(namespace)
        self.namespace = namespace
        self.globals = globals
        self.owner = owner
        self.is_class = is_class
        self.stringifiers = []

    def __missing__(self, key):
        fwdref = _Stringifier(
            ast.Name(id=key),
            globals=self.globals,
            owner=self.owner,
            is_class=self.is_class,
        )
        self.stringifiers.append(fwdref)
        return fwdref


def call_evaluate_function(evaluate, format, *, owner=None):
    """Invoke an evaluate function in the requested *format*.

    Evaluate functions are normally generated for the value of type aliases and
    for the bounds, constraints, and defaults of type parameter objects. They
    are dispatched through call_annotate_function(), flagged so that their
    result is handled as a single value rather than a dict of annotations.
    """
    return call_annotate_function(
        annotate=evaluate, format=format, owner=owner, _is_evaluate=True
    )


def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
    """Call an __annotate__ function. __annotate__ functions are normally
    generated by the compiler to defer the evaluation of annotations. They
    can be called with any of the format arguments in the Format enum, but
    compiler-generated __annotate__ functions only support the VALUE format.
    This function provides additional functionality to call __annotate__
    functions with the FORWARDREF and SOURCE formats.

    *annotate* must be an __annotate__ function, which takes a single argument
    and returns a dict of annotations.

    *format* must be a member of the Format enum or one of the corresponding
    integer values.

    *owner* can be the object that owns the annotations (i.e., the module,
    class, or function that the __annotate__ function derives from). With the
    FORWARDREF format, it is used to provide better evaluation capabilities
    on the generated ForwardRef objects.

    *_is_evaluate* is private. When true, *annotate* is an evaluate function
    that returns a single value instead of a dict (see call_evaluate_function).

    The result of ``annotate(format)`` is returned directly unless it raises
    NotImplementedError. In that case SOURCE and FORWARDREF are emulated.
    Raises RuntimeError if NotImplementedError is raised for VALUE, and
    ValueError for an unknown format.
    """
    try:
        return annotate(format)
    except NotImplementedError:
        pass
    if format == Format.SOURCE:
        # SOURCE is implemented by calling the annotate function in a special
        # environment where every name lookup results in an instance of _Stringifier.
        # _Stringifier supports every dunder operation and returns a new _Stringifier.
        # At the end, we get a dictionary that mostly contains _Stringifier objects (or
        # possibly constants if the annotate function uses them directly). We then
        # convert each of those into a string to get an approximation of the
        # original source.
        globals = _StringifierDict({})
        if annotate.__closure__:
            freevars = annotate.__code__.co_freevars
            new_closure = []
            # Every closure cell is replaced, even filled ones, so that the
            # variable's name appears in the output rather than its value.
            for i in range(len(annotate.__closure__)):
                name = freevars[i] if i < len(freevars) else "__cell__"
                fwdref = _Stringifier(ast.Name(id=name))
                new_closure.append(types.CellType(fwdref))
            closure = tuple(new_closure)
        else:
            closure = None
        func = types.FunctionType(
            annotate.__code__,
            globals,
            closure=closure,
            argdefs=annotate.__defaults__,
            kwdefaults=annotate.__kwdefaults__,
        )
        annos = func(Format.VALUE)
        if _is_evaluate:
            return annos if isinstance(annos, str) else repr(annos)
        return {
            key: val if isinstance(val, str) else repr(val)
            for key, val in annos.items()
        }
    elif format == Format.FORWARDREF:
        # FORWARDREF is implemented similarly to SOURCE, but there are two changes,
        # at the beginning and the end of the process.
        # First, while SOURCE uses an empty dictionary as the namespace, so that all
        # name lookups result in _Stringifier objects, FORWARDREF uses the globals
        # and builtins, so that defined names map to their real values.
        # Second, instead of returning strings, we want to return either real values
        # or ForwardRef objects. To do this, we keep track of all _Stringifier objects
        # created while the annotation is being evaluated, and at the end we convert
        # them all to ForwardRef objects by assigning to __class__. To make this
        # technique work, we have to ensure that the _Stringifier and ForwardRef
        # classes share the same attributes.
        # We use this technique because while the annotations are being evaluated,
        # we want to support all operations that the language allows, including even
        # __getattr__ and __eq__, and return new _Stringifier objects so we can accurately
        # reconstruct the source. But in the dictionary that we eventually return, we
        # want to return objects with more user-friendly behavior, such as an __eq__
        # that returns a bool and a defined set of attributes.
        namespace = {**annotate.__builtins__, **annotate.__globals__}
        is_class = isinstance(owner, type)
        globals = _StringifierDict(namespace, annotate.__globals__, owner, is_class)
        if annotate.__closure__:
            freevars = annotate.__code__.co_freevars
            new_closure = []
            for i, cell in enumerate(annotate.__closure__):
                try:
                    cell.cell_contents
                except ValueError:
                    # Unfilled cell (name not yet bound): substitute a stringifier
                    # that remembers the cell, so evaluate() can read it later.
                    name = freevars[i] if i < len(freevars) else "__cell__"
                    fwdref = _Stringifier(
                        ast.Name(id=name),
                        cell=cell,
                        owner=owner,
                        globals=annotate.__globals__,
                        is_class=is_class,
                    )
                    globals.stringifiers.append(fwdref)
                    new_closure.append(types.CellType(fwdref))
                else:
                    new_closure.append(cell)
            closure = tuple(new_closure)
        else:
            closure = None
        func = types.FunctionType(
            annotate.__code__,
            globals,
            closure=closure,
            argdefs=annotate.__defaults__,
            kwdefaults=annotate.__kwdefaults__,
        )
        result = func(Format.VALUE)
        for obj in globals.stringifiers:
            obj.__class__ = ForwardRef
        # Only stringifiers for missing names and unfilled cells are tracked
        # above. Ones derived from them by an operation (e.g. ``Undefined.attr``
        # or ``Undefined[int]``) are not, so convert any returned directly as a
        # value; callers should never see a raw _Stringifier there. Derived
        # stringifiers nested inside other objects (e.g. in
        # ``list[Undefined.attr]``) are not reached by this pass.
        if _is_evaluate:
            top_level = (result,)
        elif isinstance(result, dict):
            top_level = result.values()
        else:
            top_level = ()
        for value in top_level:
            if type(value) is _Stringifier:
                value.__class__ = ForwardRef
        return result
    elif format == Format.VALUE:
        # Should be impossible because __annotate__ functions must not raise
        # NotImplementedError for this format.
        raise RuntimeError("annotate function does not support VALUE format")
    else:
        raise ValueError(f"Invalid format: {format!r}")


# We use the descriptors from builtins.type instead of accessing
# .__annotations__ and .__annotate__ directly on class objects, because
# otherwise we could get wrong results in some cases involving metaclasses.
# See PEP 749.
# Each name is the bound __get__ of type's own descriptor. Calling it with a
# class object looks the attribute up through that descriptor, bypassing any
# override a metaclass might define. get_annotations() relies on the
# __annotations__ descriptor raising AttributeError for static types.
_BASE_GET_ANNOTATE = type.__dict__["__annotate__"].__get__
_BASE_GET_ANNOTATIONS = type.__dict__["__annotations__"].__get__


def get_annotate_function(obj):
    """Return the __annotate__ function of *obj*, or None if it has none.

    *obj* may be a function, class, or module, or any user-defined object with
    an ``__annotate__`` attribute. Classes are queried through type's own
    descriptor so that metaclass attributes cannot interfere.
    """
    if not isinstance(obj, type):
        return getattr(obj, "__annotate__", None)
    return _BASE_GET_ANNOTATE(obj)


def get_annotations(
    obj, *, globals=None, locals=None, eval_str=False, format=Format.VALUE
):
    """Compute the annotations dict for an object.

    obj may be a callable, class, or module, or any other object carrying an
    __annotations__ dict. Objects without annotations yield an empty dict.

    Returns a dict.  get_annotations() returns a new dict every time
    it's called; calling it twice on the same object will return two
    different but equivalent dicts.

    This function handles several details for you:

      * If eval_str is true, values of type str will
        be un-stringized using eval().  This is intended
        for use with stringized annotations
        ("from __future__ import annotations").
      * If obj doesn't have an annotations dict, returns an
        empty dict.  (Functions and methods always have an
        annotations dict; classes, modules, and other types of
        callables may not.)
      * Ignores inherited annotations on classes.  If a class
        doesn't have its own annotations dict, returns an empty dict.
      * All accesses to object members and dict values are done
        using getattr() and dict.get() for safety.
      * Always, always, always returns a freshly-created dict.

    format selects the representation of the values (see Format):

      * VALUE (the default): the values stored in __annotations__.
      * FORWARDREF or SOURCE: if obj has an __annotate__ function, it is
        called via call_annotate_function() with that format. Otherwise the
        values from __annotations__ are used. For SOURCE, non-string values
        are converted to strings: classes and functions become their
        qualified name, Ellipsis becomes "...", and anything else uses repr().

    eval_str controls whether or not values of type str are replaced
    with the result of calling eval() on those values:

      * If eval_str is true, eval() is called on values of type str.
      * If eval_str is false (the default), values of type str are unchanged.

    globals and locals are passed in to eval(); see the documentation
    for eval() for more information.  If either globals or locals is
    None, this function may replace that value with a context-specific
    default, contingent on type(obj):

      * If obj is a module, globals defaults to obj.__dict__.
      * If obj is a class, globals defaults to
        sys.modules[obj.__module__].__dict__ and locals
        defaults to the obj class namespace.
      * If obj is a callable, globals defaults to obj.__globals__,
        although if obj is a wrapped function (using
        functools.update_wrapper()) it is first unwrapped.

    Raises ValueError if eval_str is true with a format other than VALUE,
    or if __annotate__ / __annotations__ yields something that is not a dict.
    """
    if eval_str and format != Format.VALUE:
        raise ValueError("eval_str=True is only supported with format=Format.VALUE")

    # For VALUE format, we look at __annotations__ directly.
    if format != Format.VALUE:
        annotate = get_annotate_function(obj)
        if annotate is not None:
            ann = call_annotate_function(annotate, format, owner=obj)
            if not isinstance(ann, dict):
                raise ValueError(f"{obj!r}.__annotate__ returned a non-dict")
            return dict(ann)

    if isinstance(obj, type):
        try:
            ann = _BASE_GET_ANNOTATIONS(obj)
        except AttributeError:
            # For static types, the descriptor raises AttributeError.
            return {}
    else:
        ann = getattr(obj, "__annotations__", None)
    # Treat None as "no annotations" for classes as well as other objects, as
    # the error message below promises.
    if ann is None:
        return {}

    if not isinstance(ann, dict):
        raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None")

    if not ann:
        return {}

    if format == Format.SOURCE:
        # No __annotate__ function was available, so stringify the stored values
        # ourselves. eval_str is necessarily false here (checked above).
        return {key: _annotation_to_source(value) for key, value in ann.items()}

    if not eval_str:
        return dict(ann)

    if isinstance(obj, type):
        # class
        obj_globals = None
        module_name = getattr(obj, "__module__", None)
        if module_name:
            module = sys.modules.get(module_name, None)
            if module:
                obj_globals = getattr(module, "__dict__", None)
        obj_locals = dict(vars(obj))
        unwrap = obj
    elif isinstance(obj, types.ModuleType):
        # module
        obj_globals = getattr(obj, "__dict__")
        obj_locals = None
        unwrap = None
    elif callable(obj):
        # this includes types.Function, types.BuiltinFunctionType,
        # types.BuiltinMethodType, functools.partial, functools.singledispatch,
        # "class funclike" from Lib/test/test_inspect... on and on it goes.
        obj_globals = getattr(obj, "__globals__", None)
        obj_locals = None
        unwrap = obj
    else:
        # Any other object with an __annotations__ dict: there is no namespace
        # to infer, so only caller-supplied globals/locals are used.
        obj_globals = obj_locals = unwrap = None

    if unwrap is not None:
        while True:
            if hasattr(unwrap, "__wrapped__"):
                unwrap = unwrap.__wrapped__
                continue
            if isinstance(unwrap, functools.partial):
                unwrap = unwrap.func
                continue
            break
        if hasattr(unwrap, "__globals__"):
            obj_globals = unwrap.__globals__

    if globals is None:
        globals = obj_globals
    if locals is None:
        locals = obj_locals

    # "Inject" type parameters into the local namespace
    # (unless they are shadowed by assignments *in* the local namespace),
    # as a way of emulating annotation scopes when calling `eval()`
    if type_params := getattr(obj, "__type_params__", ()):
        if locals is None:
            locals = {}
        locals = {param.__name__: param for param in type_params} | locals

    return_value = {
        key: value if not isinstance(value, str) else eval(value, globals, locals)
        for key, value in ann.items()
    }
    return return_value


def _annotation_to_source(value):
    """Return a best-effort source-code string for an evaluated annotation value.

    Strings are returned unchanged and Ellipsis becomes "...". Classes and
    functions are rendered by qualified name, prefixed by their module unless
    it is ``builtins`` or unknown. Anything else falls back to repr().
    """
    if isinstance(value, str):
        return value
    if value is ...:
        return "..."
    if isinstance(value, (type, types.FunctionType, types.BuiltinFunctionType)):
        qualname = getattr(value, "__qualname__", None)
        if isinstance(qualname, str):
            module = getattr(value, "__module__", None)
            if module in (None, "builtins"):
                return qualname
            return f"{module}.{qualname}"
    return repr(value)
